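'''
First-stage test: slide a fixed-size window over a cropped test image,
classify every patch with a pre-trained model, and save the patches
predicted as positive.
'''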


import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn import svm
from scipy.misc import imresize
import random
from sklearn import model_selection
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from PIL import Image
import pickle
import math
from skimage.feature import greycomatrix, greycoprops
import math
from skimage.feature import greycomatrix, greycoprops
from skimage.morphology import *
from skimage.feature import hog
import multiprocessing as mp
import itertools
import re
import warnings
import threading
import functions as fc
import sys
warnings.filterwarnings("ignore")




class TestWhole:
    """Sliding-window patch classifier over one (cropped) test image.

    Every ``patchSize`` x ``patchSize`` window, taken every ``stride``
    pixels, is turned into a HOG-based feature vector by the project
    ``functions`` module (``fc``). The vector is optionally scaled and
    PCA-projected, then passed to ``classifier``. Windows predicted as
    class 1 are written to ``pos_root`` as ``<row>_<col>.jpg``. The row and
    column are given in the coordinates of the uncropped image.
    """

    def __init__(self, classifier, pos_root, stride, img, pca, use_color,
                 scaler, corner=None):
        """
        Args:
            classifier: fitted model exposing ``predict``.
            pos_root: directory that receives positive patches. It is
                created if missing, and regular files already in it are
                deleted.
            stride: step in pixels between neighbouring windows.
            img: the image to be scanned (kept as ``big_img``).
            pca: fitted projection exposing ``transform``, or a falsy value
                to skip the projection.
            use_color: if true, use ``fc.hog_color`` features followed by
                ``scaler``. Otherwise use ``fc.hog_``.
            scaler: fitted scaler exposing ``transform``. It is only used
                when ``use_color`` is true.
            corner: optional crop rectangle, indexed like the ``__main__``
                crop (rows ``corner[2]:corner[3]``, columns
                ``corner[0]:corner[1]``). It is used to convert window
                positions back to full-image coordinates. When omitted,
                the module-level ``corner`` is used, as before.
        """
        self.classifier = classifier
        self.pos_root = pos_root
        self.stride = stride
        self.patchSize = 64
        self.patchs = []
        self.count = 0
        self.big_img = img
        self.use_color = use_color
        # Created before the workers start, so every worker shares it.
        self.lock = mp.Lock()
        self.pca = pca
        self.scaler = scaler
        self.corner = corner
        print('use_color:{}'.format(self.use_color))
        print('pca:', self.pca)
        print('scaler', self.scaler)
        print('saving root:', self.pos_root)
        _ = input('press enter to continue!')

        print('clearing folder...')
        # Clear old images so the folder only holds results of this run.
        os.makedirs(self.pos_root, exist_ok=True)
        for file_name in os.listdir(self.pos_root):
            file_path = os.path.join(self.pos_root, file_name)
            if os.path.isfile(file_path):
                os.unlink(file_path)
        print('finish clearing...')

    def TestPatch(self, patch_RGB, global_w, global_h):
        """Classify one patch and save it to ``pos_root`` if predicted 1.

        Args:
            patch_RGB: a single image patch, in the channel order it was
                loaded with (``cv2.imread`` yields BGR, which
                ``cv2.imwrite`` expects).
            global_w: column of the patch's left edge in the full image.
            global_h: row of the patch's top edge in the full image.
        """
        if self.use_color:
            feature_per_patch = fc.hog_color(patch_RGB, 10)
            feature_per_patch = self.scaler.transform(feature_per_patch)
        else:
            feature_per_patch = fc.hog_(patch_RGB)

        if self.pca:
            feature_per_patch = self.pca.transform(feature_per_patch)
        prediction = self.classifier.predict(feature_per_patch)[0]

        if prediction == 1:
            file_name = '{}_{}.jpg'.format(global_h, global_w)
            # Use the instance lock rather than a global that only exists
            # under __main__. The context manager releases it even if
            # imwrite raises.
            with self.lock:
                cv2.imwrite(os.path.join(self.pos_root, file_name), patch_RGB)

    def Run(self, img_class, start, end):
        """Scan every full window inside one vertical strip.

        Args:
            img_class: object whose ``data`` attribute holds the strip
                (H x W x C array).
            start: column of the strip's left edge within the cropped image.
            end: column one past the strip's right edge. It is kept for
                interface compatibility and is not used.
        """
        img = img_class.data
        h, w = img.shape[0], img.shape[1]
        if self.corner is not None:
            offset_w, offset_h = self.corner[0], self.corner[2]
        else:
            # Backward-compatible fallback: the module-level ``corner`` is
            # only defined when the script runs as __main__, and is only
            # visible to forked (not spawned) workers.
            offset_w, offset_h = corner[0], corner[2]

        # Window origins whose patch fits entirely inside the strip.
        for i in range(0, h - self.patchSize + 1, self.stride):
            for j in range(0, w - self.patchSize + 1, self.stride):
                print('height: %d/%d, width %d/%d' % (i, h, j, w))
                patch = img[i:i + self.patchSize, j:j + self.patchSize, :].copy()
                self.TestPatch(patch, j + start + offset_w, i + offset_h)

    def _column_ranges(self, width, n_parts):
        """Split the window columns of a ``width``-wide image into strips.

        The window origins ``0, stride, 2*stride, ...`` are divided into at
        most ``n_parts`` contiguous groups. Each group becomes a column
        range ``(start, end)`` that contains exactly those windows. The
        union of all strips therefore scans the same global grid as a
        single full-width pass, and the result does not depend on
        ``n_parts``.

        Returns:
            List of ``(start, end)`` column ranges. It is empty when no
            window fits.
        """
        positions = list(range(0, width - self.patchSize + 1, self.stride))
        if not positions:
            return []
        n_parts = max(1, min(n_parts, len(positions)))
        base, extra = divmod(len(positions), n_parts)
        ranges = []
        first = 0
        for part in range(n_parts):
            size = base + (1 if part < extra else 0)
            group = positions[first:first + size]
            first += size
            ranges.append((group[0], group[-1] + self.patchSize))
        return ranges

    def Mp(self, img: np.ndarray, n_procs=32):
        """Scan ``img`` in parallel, one worker process per vertical strip.

        Args:
            img: the (cropped) image to scan, H x W x C.
            n_procs: maximum number of worker processes. The default of 32
                matches the previously hard-coded value.
        """
        procs = []
        for start, end in self._column_ranges(img.shape[1], n_procs):
            strip = image(img[:, start:end].copy())
            proc = mp.Process(target=self.Run, args=(strip, start, end))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()

class image:
    """Thin holder that exposes an array (or any object) as ``.data``.

    Worker processes receive their image strip wrapped in this object.
    """

    def __init__(self, img):
        payload = img
        self.data = payload

if __name__=='__main__':
    use_color=True

    # Module-level lock, kept for code that refers to the global name
    # ``lock`` inside forked workers.
    lock = mp.Lock()
    stride=8
    project_root='/media/deeplearning/DATA/system/wbw/GDA_project/MyProject'
    color_model_dir=project_root+'/hog_color_pca_new/standard'
    gray_model_dir=project_root+'/comparison_methods/hog_svm'
    # Positive patches go to the colour-model results folder in both modes.
    pos_root=color_model_dir+'/results'

    # Crop rectangle: rows corner[2]:corner[3], columns corner[0]:corner[1].
    # Keep this a module-level name; TestWhole.Run may read it.
    corner=fc.get_corner()
    image_path='/media/deeplearning/DATA/system/wbw/GDA_project/data/imageTest/pic_1.jpg'
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        sys.exit('could not read image: {}'.format(image_path))
    img=img[corner[2]:corner[3],corner[0]:corner[1],:]

    def _load_pickle(path):
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(path,'rb') as f:
            return pickle.load(f)

    # The non-colour branch loads no scaler (TestPatch only uses it when
    # use_color is true). Default it so the TestWhole call below does not
    # hit an undefined name.
    scaler=None
    if use_color:
        clf=_load_pickle(color_model_dir+'/svm.p')
        pca=_load_pickle(color_model_dir+'/pca.p')
        scaler=_load_pickle(color_model_dir+'/scaler.p')
    else:
        clf=_load_pickle(gray_model_dir+'/svm.p')
        pca=_load_pickle(gray_model_dir+'/pca.p')

    test=TestWhole(clf,pos_root,stride,img,pca,use_color,scaler)
    test.Mp(img)
